{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers installation\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image captioning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Image captioning is the task of predicting a caption for a given image. Common real-world applications include\n",
    "aiding visually impaired people by helping them navigate different situations. Image captioning therefore\n",
    "helps make content more accessible by describing images to people.\n",
    "\n",
    "This guide will show you how to:\n",
    "\n",
    "* Fine-tune an image captioning model.\n",
    "* Use the fine-tuned model for inference.\n",
    "\n",
    "Before you begin, make sure you have all the necessary libraries installed:\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate -q\n",
    "pip install jiwer -q\n",
    "```\n",
    "\n",
    "We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pokémon BLIP captions dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the 🤗 Datasets library to load a dataset that consists of {image-caption} pairs. To create your own image captioning dataset\n",
    "in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"lambdalabs/pokemon-blip-captions\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['image', 'text'],\n",
    "        num_rows: 833\n",
    "    })\n",
    "})\n",
    "```\n",
    "\n",
    "The dataset has two features, `image` and `text`.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample one of the available captions during training.\n",
    "\n",
    "</Tip>\n",
    "\n",
    "Split the dataset's train split into a train and a test set with the `train_test_split` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds[\"train\"].train_test_split(test_size=0.1)\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ]
  },
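  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tip above mentions sampling one caption at random when an image has several. A minimal sketch of that idea follows. The `samples` list, its `captions` key, and the `pick_caption` helper are hypothetical and only serve as illustration; the Pokémon dataset used in this guide has a single `text` caption per image.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "# Hypothetical records: each image comes with one or more reference captions.\n",
    "samples = [\n",
    "    {'image': 'img_0.png', 'captions': ['a red bird', 'a small red bird on a branch']},\n",
    "    {'image': 'img_1.png', 'captions': ['a blue turtle']},\n",
    "]\n",
    "\n",
    "\n",
    "def pick_caption(sample, rng=random):\n",
    "    # Draw one caption uniformly at random from those available for this image.\n",
    "    return rng.choice(sample['captions'])\n",
    "\n",
    "\n",
    "rng = random.Random(0)  # seeded so the sketch is reproducible\n",
    "for epoch in range(2):\n",
    "    # Each pass over the data may pair an image with a different caption.\n",
    "    pairs = [(s['image'], pick_caption(s, rng)) for s in samples]\n",
    "    print(epoch, pairs)\n",
    "```\n",
    "\n",
    "In a real pipeline, a call like `pick_caption` would typically live inside the dataset transform, so that a fresh caption is drawn every time an example is read."
   ]
  },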
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's visualize a few samples from the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textwrap import wrap\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_images(images, captions):\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    for i in range(len(images)):\n",
    "        ax = plt.subplot(1, len(images), i + 1)\n",
    "        caption = captions[i]\n",
    "        caption = \"\\n\".join(wrap(caption, 12))\n",
    "        plt.title(caption)\n",
    "        plt.imshow(images[i])\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "\n",
    "sample_images_to_visualize = [np.array(train_ds[i][\"image\"]) for i in range(5)]\n",
    "sample_captions = [train_ds[i][\"text\"] for i in range(5)]\n",
    "plot_images(sample_images_to_visualize, sample_captions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png\" alt=\"Sample training images\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the dataset has two modalities (image and text), the preprocessing pipeline handles both the images and the captions.\n",
    "\n",
    "To do so, load the processor class associated with the model you are about to fine-tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "checkpoint = \"microsoft/git-base\"\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The processor will internally preprocess the image (which includes resizing and pixel scaling) and tokenize the caption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(example_batch):\n",
    "    images = [x for x in example_batch[\"image\"]]\n",
    "    captions = [x for x in example_batch[\"text\"]]\n",
    "    inputs = processor(images=images, text=captions, padding=\"max_length\")\n",
    "    inputs.update({\"labels\": inputs[\"input_ids\"]})\n",
    "    return inputs\n",
    "\n",
    "\n",
    "train_ds.set_transform(transforms)\n",
    "test_ds.set_transform(transforms)"
   ]
  },
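  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The transform above asks the processor to pad every caption with the `max_length` strategy. As a rough illustration of what that means, the sketch below pads lists of token ids to a fixed length and builds the matching attention mask. The `pad_to_max_length` helper, the token ids, and the pad id of 0 are assumptions made for this example, not values taken from the GIT processor.\n",
    "\n",
    "```python\n",
    "def pad_to_max_length(token_ids, max_length, pad_id=0):\n",
    "    # Truncation is out of scope here; reject sequences that are too long.\n",
    "    if len(token_ids) > max_length:\n",
    "        raise ValueError('sequence is longer than max_length')\n",
    "    n_pad = max_length - len(token_ids)\n",
    "    input_ids = token_ids + [pad_id] * n_pad\n",
    "    # 1 marks real tokens, 0 marks padding positions.\n",
    "    attention_mask = [1] * len(token_ids) + [0] * n_pad\n",
    "    return input_ids, attention_mask\n",
    "\n",
    "\n",
    "# Hypothetical token ids for two captions of different lengths.\n",
    "batch = [[101, 7, 42, 102], [101, 9, 102]]\n",
    "for ids in batch:\n",
    "    print(pad_to_max_length(ids, max_length=6))\n",
    "```\n",
    "\n",
    "Padding everything to one length keeps all examples the same shape, which makes them easy to stack into a batch."
   ]
  },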
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the dataset ready, you can now set up the model for fine-tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a base model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load [\"microsoft/git-base\"](https://huggingface.co/microsoft/git-base) into an [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Image captioning models are typically evaluated with the [Rouge Score](https://huggingface.co/spaces/evaluate-metric/rouge) or the [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer) (WER). For this guide, you will use the WER.\n",
    "\n",
    "We use the 🤗 Evaluate library to do so. For potential limitations and other gotchas of the WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "import torch\n",
    "\n",
    "wer = load(\"wer\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predicted = logits.argmax(-1)\n",
    "    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)\n",
    "    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)\n",
    "    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)\n",
    "    return {\"wer_score\": wer_score}"
   ]
  },
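  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the metric used above: WER counts the word-level substitutions, deletions, and insertions needed to turn the prediction into the reference, divided by the number of reference words. The following is a simplified pure-Python sketch for a single pair. `word_error_rate` is a hypothetical helper written for this illustration, not the implementation used by 🤗 Evaluate.\n",
    "\n",
    "```python\n",
    "def word_error_rate(reference, prediction):\n",
    "    # Split both strings into words; WER works on word sequences.\n",
    "    ref = reference.split()\n",
    "    hyp = prediction.split()\n",
    "    if not ref:\n",
    "        raise ValueError('reference must contain at least one word')\n",
    "\n",
    "    # prev[j] is the edit distance between the reference words seen so far\n",
    "    # and the first j predicted words (Levenshtein dynamic programming).\n",
    "    prev = list(range(len(hyp) + 1))\n",
    "    for i, ref_word in enumerate(ref, start=1):\n",
    "        curr = [i] + [0] * len(hyp)\n",
    "        for j, hyp_word in enumerate(hyp, start=1):\n",
    "            cost = 0 if ref_word == hyp_word else 1\n",
    "            curr[j] = min(\n",
    "                prev[j] + 1,         # deletion: a reference word is missing\n",
    "                curr[j - 1] + 1,     # insertion: an extra predicted word\n",
    "                prev[j - 1] + cost,  # match or substitution\n",
    "            )\n",
    "        prev = curr\n",
    "    return prev[-1] / len(ref)\n",
    "\n",
    "\n",
    "reference = 'a drawing of a pink and blue pokemon'\n",
    "prediction = 'a drawing of a blue pokemon'\n",
    "# Two reference words are missing out of eight, so this prints 0.25.\n",
    "print(word_error_rate(reference, prediction))\n",
    "```\n",
    "\n",
    "Lower is better: a WER of 0 means the prediction matches the reference word for word, and values above 1 are possible when the prediction contains many extra words."
   ]
  },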
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you are ready to start fine-tuning the model. You will use the 🤗 `Trainer` for this.\n",
    "\n",
    "First, define the training arguments using `TrainingArguments`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "model_name = checkpoint.split(\"/\")[1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-pokemon\",\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=50,\n",
    "    fp16=True,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=2,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=50,\n",
    "    logging_steps=50,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    "    label_names=[\"labels\"],\n",
    "    load_best_model_at_end=True,\n",
    ")"
   ]
  },
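  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the arguments above: with gradient accumulation, each optimizer update is based on `per_device_train_batch_size * gradient_accumulation_steps` examples per device. The arithmetic below assumes a single device; `num_devices` is an assumption for this sketch, not something the guide specifies.\n",
    "\n",
    "```python\n",
    "# Values copied from the TrainingArguments above.\n",
    "per_device_train_batch_size = 32\n",
    "gradient_accumulation_steps = 2\n",
    "\n",
    "# Assumption: training runs on a single device (for example, one GPU).\n",
    "num_devices = 1\n",
    "\n",
    "# Number of examples that contribute to each optimizer update.\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices\n",
    "print(effective_batch_size)  # 64\n",
    "```\n",
    "\n",
    "If you run out of memory, one option is to halve the per-device batch size and double the accumulation steps, which keeps this product unchanged."
   ]
  },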
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then pass the training arguments, along with the datasets and the model, to the 🤗 `Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To start training, simply call `train()` on the `Trainer` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should see the training loss drop smoothly as training progresses.\n",
    "\n",
    "Once training is completed, share your model to the Hub with the `push_to_hub()` method so everyone can use your model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load a sample image to test the model. The code below downloads a Pokémon image from the Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "url = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png\" alt=\"Test image\"/>\n",
    "</div>\n",
    "\n",
    "Prepare the image for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "inputs = processor(images=image, return_tensors=\"pt\").to(device)\n",
    "pixel_values = inputs.pixel_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call `generate` and decode the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_ids = model.generate(pixel_values=pixel_values, max_length=50)\n",
    "generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "print(generated_caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "a drawing of a pink and blue pokemon\n",
    "```\n",
    "\n",
    "Looks like the fine-tuned model generated a pretty good caption!"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
